{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Action2  对Titanic数据进行清洗，建模并对乘客生存进行预测。使用之前介绍过的10种模型中的至少2种（包括TPOT）  \n",
    "1、完成代码，结果正确（20points）     \n",
    "2、能使用TPOT完成预测（20points）  "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "PassengerId 用户编号：\n",
    "记录乘客的Id编号。经过了解后：\n",
    "\n",
    "并没有查到其构成具有特别的实际意义（如身份证的构成每一位都是有实际意义的）；\n",
    "仅作为唯一标识来定位到某一乘客身上（唯一值同总数据量一样）；\n",
    "因此认为不具有分析的价值，过后也会将它进行删除处理。\n",
    "\n",
    "Survived 是否存活（label）：\n",
    "描述乘客是否存活。\n",
    "\n",
    "0 - 用户未能存活；\n",
    "1- 用户存活；\n",
    "Pclass（用户阶级）：\n",
    "描述用户所属的等级，总共分为三等，用1、2、3来描述，其中：\n",
    "\n",
    "1 - 1st class，高等用户；\n",
    "2 - 2nd class，中等用户；\n",
    "3 - 3rd class，低等用户；\n",
    "Name（名字）：\n",
    "描述乘客的全名。例如上例中的 Rugg, Miss. Emily 中：\n",
    "\n",
    "Rugg ：first name，即名；\n",
    "Miss. ：title，即称谓；\n",
    "Emily ：last name，即姓\n",
    "在登记乘客姓名时全都是用这种方法进行记录的；\n",
    "\n",
    "Sex（性别）：\n",
    "描述乘客的性别，其中：\n",
    "\n",
    "male - 男性；\n",
    "female - 女性；\n",
    "Age（年龄）：\n",
    "描述乘客的年龄，其中有部分缺失值，需要用一些手段将她们补全，具体的方法方在下面数据清洗中；\n",
    "\n",
    "SibSp 和 Parch：\n",
    "SibSp：描述了泰坦尼克号上与乘客同行的兄弟姐妹（Siblings）和配偶（Spouse）数目；\n",
    "Parch：描述了泰坦尼克号上与乘客同行的家长（Parents）和孩子（Children）数目；\n",
    "Ticket（船票号）：\n",
    "描述乘客登船所使用的船票编号。虽然它没有编码上的规律，不存在缺失值，但是唯一值可以看到，同之前唯一定位的乘客编号不同，也就是说可能会有人重复使用船票的情况，具体处理会在数据清洗中介绍，我会找到资料支撑和这一想法；\n",
    "\n",
    "Fare（乘客费用）：\n",
    "描述乘客上传所花费的费用；\n",
    "\n",
    "Cabin（船舱）：\n",
    "描述用户所住的船舱编号。由两部分组成，仓位号和房间编号，如C88中，C和88分别对应C仓位和88号房间。本字段缺失值较多，具体处理方法会在后面的数据清洗部分进行介绍。\n",
    "\n",
    "Embarked（港口）：\n",
    "描述乘客上船时的港口，包含三种类型：\n",
    "\n",
    "C：Cherbourg；\n",
    "Q：Queenstown；\n",
    "S：Southampton；"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import pandas as pd\n",
    "from sklearn.tree import DecisionTreeClassifier\n",
    "from sklearn.feature_extraction import DictVectorizer #特征提取\n",
    "from sklearn.model_selection import cross_val_score\n",
    "from sklearn import metrics"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 数据加载\n",
    "train_data =pd.read_csv('./train.csv')\n",
    "test_data = pd.read_csv('./test.csv')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "查看数据信息：列名、非空个数、类型等\n",
      "<class 'pandas.core.frame.DataFrame'>\n",
      "RangeIndex: 891 entries, 0 to 890\n",
      "Data columns (total 12 columns):\n",
      " #   Column       Non-Null Count  Dtype  \n",
      "---  ------       --------------  -----  \n",
      " 0   PassengerId  891 non-null    int64  \n",
      " 1   Survived     891 non-null    int64  \n",
      " 2   Pclass       891 non-null    int64  \n",
      " 3   Name         891 non-null    object \n",
      " 4   Sex          891 non-null    object \n",
      " 5   Age          714 non-null    float64\n",
      " 6   SibSp        891 non-null    int64  \n",
      " 7   Parch        891 non-null    int64  \n",
      " 8   Ticket       891 non-null    object \n",
      " 9   Fare         891 non-null    float64\n",
      " 10  Cabin        204 non-null    object \n",
      " 11  Embarked     889 non-null    object \n",
      "dtypes: float64(2), int64(5), object(5)\n",
      "memory usage: 83.7+ KB\n",
      "None\n"
     ]
    }
   ],
   "source": [
    "# 数据探索\n",
    "print('查看数据信息：列名、非空个数、类型等')\n",
    "print(train_data.info())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "查看数据摘要\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>PassengerId</th>\n",
       "      <th>Survived</th>\n",
       "      <th>Pclass</th>\n",
       "      <th>Age</th>\n",
       "      <th>SibSp</th>\n",
       "      <th>Parch</th>\n",
       "      <th>Fare</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>count</th>\n",
       "      <td>891.000000</td>\n",
       "      <td>891.000000</td>\n",
       "      <td>891.000000</td>\n",
       "      <td>714.000000</td>\n",
       "      <td>891.000000</td>\n",
       "      <td>891.000000</td>\n",
       "      <td>891.000000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>mean</th>\n",
       "      <td>446.000000</td>\n",
       "      <td>0.383838</td>\n",
       "      <td>2.308642</td>\n",
       "      <td>29.699118</td>\n",
       "      <td>0.523008</td>\n",
       "      <td>0.381594</td>\n",
       "      <td>32.204208</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>std</th>\n",
       "      <td>257.353842</td>\n",
       "      <td>0.486592</td>\n",
       "      <td>0.836071</td>\n",
       "      <td>14.526497</td>\n",
       "      <td>1.102743</td>\n",
       "      <td>0.806057</td>\n",
       "      <td>49.693429</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>min</th>\n",
       "      <td>1.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>0.420000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>25%</th>\n",
       "      <td>223.500000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>2.000000</td>\n",
       "      <td>20.125000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>7.910400</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>50%</th>\n",
       "      <td>446.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>3.000000</td>\n",
       "      <td>28.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>14.454200</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>75%</th>\n",
       "      <td>668.500000</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>3.000000</td>\n",
       "      <td>38.000000</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>0.000000</td>\n",
       "      <td>31.000000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>max</th>\n",
       "      <td>891.000000</td>\n",
       "      <td>1.000000</td>\n",
       "      <td>3.000000</td>\n",
       "      <td>80.000000</td>\n",
       "      <td>8.000000</td>\n",
       "      <td>6.000000</td>\n",
       "      <td>512.329200</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "       PassengerId    Survived      Pclass         Age       SibSp  \\\n",
       "count   891.000000  891.000000  891.000000  714.000000  891.000000   \n",
       "mean    446.000000    0.383838    2.308642   29.699118    0.523008   \n",
       "std     257.353842    0.486592    0.836071   14.526497    1.102743   \n",
       "min       1.000000    0.000000    1.000000    0.420000    0.000000   \n",
       "25%     223.500000    0.000000    2.000000   20.125000    0.000000   \n",
       "50%     446.000000    0.000000    3.000000   28.000000    0.000000   \n",
       "75%     668.500000    1.000000    3.000000   38.000000    1.000000   \n",
       "max     891.000000    1.000000    3.000000   80.000000    8.000000   \n",
       "\n",
       "            Parch        Fare  \n",
       "count  891.000000  891.000000  \n",
       "mean     0.381594   32.204208  \n",
       "std      0.806057   49.693429  \n",
       "min      0.000000    0.000000  \n",
       "25%      0.000000    7.910400  \n",
       "50%      0.000000   14.454200  \n",
       "75%      0.000000   31.000000  \n",
       "max      6.000000  512.329200  "
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "print('查看数据摘要')\n",
    "train_data.describe()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "查看离散数据分布\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>Name</th>\n",
       "      <th>Sex</th>\n",
       "      <th>Ticket</th>\n",
       "      <th>Cabin</th>\n",
       "      <th>Embarked</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>count</th>\n",
       "      <td>891</td>\n",
       "      <td>891</td>\n",
       "      <td>891</td>\n",
       "      <td>204</td>\n",
       "      <td>889</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>unique</th>\n",
       "      <td>891</td>\n",
       "      <td>2</td>\n",
       "      <td>681</td>\n",
       "      <td>147</td>\n",
       "      <td>3</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>top</th>\n",
       "      <td>Bryhl, Mr. Kurt Arnold Gottfrid</td>\n",
       "      <td>male</td>\n",
       "      <td>CA. 2343</td>\n",
       "      <td>G6</td>\n",
       "      <td>S</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>freq</th>\n",
       "      <td>1</td>\n",
       "      <td>577</td>\n",
       "      <td>7</td>\n",
       "      <td>4</td>\n",
       "      <td>644</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                                   Name   Sex    Ticket Cabin Embarked\n",
       "count                               891   891       891   204      889\n",
       "unique                              891     2       681   147        3\n",
       "top     Bryhl, Mr. Kurt Arnold Gottfrid  male  CA. 2343    G6        S\n",
       "freq                                  1   577         7     4      644"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "print('查看离散数据分布')\n",
    "train_data.describe(include=['O'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>PassengerId</th>\n",
       "      <th>Survived</th>\n",
       "      <th>Pclass</th>\n",
       "      <th>Name</th>\n",
       "      <th>Sex</th>\n",
       "      <th>Age</th>\n",
       "      <th>SibSp</th>\n",
       "      <th>Parch</th>\n",
       "      <th>Ticket</th>\n",
       "      <th>Fare</th>\n",
       "      <th>Cabin</th>\n",
       "      <th>Embarked</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>3</td>\n",
       "      <td>Braund, Mr. Owen Harris</td>\n",
       "      <td>male</td>\n",
       "      <td>22.0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>A/5 21171</td>\n",
       "      <td>7.2500</td>\n",
       "      <td>NaN</td>\n",
       "      <td>S</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>2</td>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>Cumings, Mrs. John Bradley (Florence Briggs Th...</td>\n",
       "      <td>female</td>\n",
       "      <td>38.0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>PC 17599</td>\n",
       "      <td>71.2833</td>\n",
       "      <td>C85</td>\n",
       "      <td>C</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>3</td>\n",
       "      <td>1</td>\n",
       "      <td>3</td>\n",
       "      <td>Heikkinen, Miss. Laina</td>\n",
       "      <td>female</td>\n",
       "      <td>26.0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>STON/O2. 3101282</td>\n",
       "      <td>7.9250</td>\n",
       "      <td>NaN</td>\n",
       "      <td>S</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>4</td>\n",
       "      <td>1</td>\n",
       "      <td>1</td>\n",
       "      <td>Futrelle, Mrs. Jacques Heath (Lily May Peel)</td>\n",
       "      <td>female</td>\n",
       "      <td>35.0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>113803</td>\n",
       "      <td>53.1000</td>\n",
       "      <td>C123</td>\n",
       "      <td>S</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>5</td>\n",
       "      <td>0</td>\n",
       "      <td>3</td>\n",
       "      <td>Allen, Mr. William Henry</td>\n",
       "      <td>male</td>\n",
       "      <td>35.0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>373450</td>\n",
       "      <td>8.0500</td>\n",
       "      <td>NaN</td>\n",
       "      <td>S</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   PassengerId  Survived  Pclass  \\\n",
       "0            1         0       3   \n",
       "1            2         1       1   \n",
       "2            3         1       3   \n",
       "3            4         1       1   \n",
       "4            5         0       3   \n",
       "\n",
       "                                                Name     Sex   Age  SibSp  \\\n",
       "0                            Braund, Mr. Owen Harris    male  22.0      1   \n",
       "1  Cumings, Mrs. John Bradley (Florence Briggs Th...  female  38.0      1   \n",
       "2                             Heikkinen, Miss. Laina  female  26.0      0   \n",
       "3       Futrelle, Mrs. Jacques Heath (Lily May Peel)  female  35.0      1   \n",
       "4                           Allen, Mr. William Henry    male  35.0      0   \n",
       "\n",
       "   Parch            Ticket     Fare Cabin Embarked  \n",
       "0      0         A/5 21171   7.2500   NaN        S  \n",
       "1      0          PC 17599  71.2833   C85        C  \n",
       "2      0  STON/O2. 3101282   7.9250   NaN        S  \n",
       "3      0            113803  53.1000  C123        S  \n",
       "4      0            373450   8.0500   NaN        S  "
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "train_data.head()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 使用平均年龄来填充年龄中的nan值\n",
    "train_data['Age'].fillna(train_data['Age'].mean(), inplace=True)\n",
    "test_data['Age'].fillna(test_data['Age'].mean(),inplace=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 使用票价的均值填充票价中的nan值\n",
    "train_data['Fare'].fillna(train_data['Fare'].mean(), inplace=True)\n",
    "test_data['Fare'].fillna(test_data['Fare'].mean(),inplace=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "S    644\n",
      "C    168\n",
      "Q     77\n",
      "Name: Embarked, dtype: int64\n"
     ]
    }
   ],
   "source": [
    "print(train_data['Embarked'].value_counts())\n",
    "# 使用登录最多的港口来填充登录港口的nan值\n",
    "train_data['Embarked'].fillna('S', inplace=True)\n",
    "test_data['Embarked'].fillna('S',inplace=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>Pclass</th>\n",
       "      <th>Sex</th>\n",
       "      <th>Age</th>\n",
       "      <th>SibSp</th>\n",
       "      <th>Parch</th>\n",
       "      <th>Fare</th>\n",
       "      <th>Embarked</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>3</td>\n",
       "      <td>male</td>\n",
       "      <td>22.000000</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>7.2500</td>\n",
       "      <td>S</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>1</td>\n",
       "      <td>female</td>\n",
       "      <td>38.000000</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>71.2833</td>\n",
       "      <td>C</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>3</td>\n",
       "      <td>female</td>\n",
       "      <td>26.000000</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>7.9250</td>\n",
       "      <td>S</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>1</td>\n",
       "      <td>female</td>\n",
       "      <td>35.000000</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>53.1000</td>\n",
       "      <td>S</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>3</td>\n",
       "      <td>male</td>\n",
       "      <td>35.000000</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>8.0500</td>\n",
       "      <td>S</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>...</th>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "      <td>...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>886</th>\n",
       "      <td>2</td>\n",
       "      <td>male</td>\n",
       "      <td>27.000000</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>13.0000</td>\n",
       "      <td>S</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>887</th>\n",
       "      <td>1</td>\n",
       "      <td>female</td>\n",
       "      <td>19.000000</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>30.0000</td>\n",
       "      <td>S</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>888</th>\n",
       "      <td>3</td>\n",
       "      <td>female</td>\n",
       "      <td>29.699118</td>\n",
       "      <td>1</td>\n",
       "      <td>2</td>\n",
       "      <td>23.4500</td>\n",
       "      <td>S</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>889</th>\n",
       "      <td>1</td>\n",
       "      <td>male</td>\n",
       "      <td>26.000000</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>30.0000</td>\n",
       "      <td>C</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>890</th>\n",
       "      <td>3</td>\n",
       "      <td>male</td>\n",
       "      <td>32.000000</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>7.7500</td>\n",
       "      <td>Q</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "<p>891 rows × 7 columns</p>\n",
       "</div>"
      ],
      "text/plain": [
       "     Pclass     Sex        Age  SibSp  Parch     Fare Embarked\n",
       "0         3    male  22.000000      1      0   7.2500        S\n",
       "1         1  female  38.000000      1      0  71.2833        C\n",
       "2         3  female  26.000000      0      0   7.9250        S\n",
       "3         1  female  35.000000      1      0  53.1000        S\n",
       "4         3    male  35.000000      0      0   8.0500        S\n",
       "..      ...     ...        ...    ...    ...      ...      ...\n",
       "886       2    male  27.000000      0      0  13.0000        S\n",
       "887       1  female  19.000000      0      0  30.0000        S\n",
       "888       3  female  29.699118      1      2  23.4500        S\n",
       "889       1    male  26.000000      0      0  30.0000        C\n",
       "890       3    male  32.000000      0      0   7.7500        Q\n",
       "\n",
       "[891 rows x 7 columns]"
      ]
     },
     "execution_count": 24,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# 特征选择\n",
    "features = ['Pclass', 'Sex', 'Age', 'SibSp', 'Parch', 'Fare', 'Embarked']\n",
    "train_features = train_data[features]\n",
    "train_labels = train_data['Survived']\n",
    "test_features = test_data[features]\n",
    "train_features"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[22.        ,  0.        ,  0.        , ...,  0.        ,\n",
       "         1.        ,  1.        ],\n",
       "       [38.        ,  1.        ,  0.        , ...,  1.        ,\n",
       "         0.        ,  1.        ],\n",
       "       [26.        ,  0.        ,  0.        , ...,  1.        ,\n",
       "         0.        ,  0.        ],\n",
       "       ...,\n",
       "       [29.69911765,  0.        ,  0.        , ...,  1.        ,\n",
       "         0.        ,  1.        ],\n",
       "       [26.        ,  1.        ,  0.        , ...,  0.        ,\n",
       "         1.        ,  0.        ],\n",
       "       [32.        ,  0.        ,  1.        , ...,  0.        ,\n",
       "         1.        ,  0.        ]])"
      ]
     },
     "execution_count": 27,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# 特征提取\n",
    "# 'records' : list like [{column -> value}, ... , {column -> value}]\n",
    "dvec=DictVectorizer(sparse=False)\n",
    "train_features = dvec.fit_transform(train_features.to_dict(orient='record'))\n",
    "train_features"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['Age',\n",
       " 'Embarked=C',\n",
       " 'Embarked=Q',\n",
       " 'Embarked=S',\n",
       " 'Fare',\n",
       " 'Parch',\n",
       " 'Pclass',\n",
       " 'Sex=female',\n",
       " 'Sex=male',\n",
       " 'SibSp']"
      ]
     },
     "execution_count": 32,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# dvec.get_feature_names()\n",
    "dvec.feature_names_"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 43,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "DecisionTreeClassifier(criterion='entropy')"
      ]
     },
     "execution_count": 43,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# 构造ID3决策树\n",
    "clf = DecisionTreeClassifier(criterion='entropy')\n",
    "clf.fit(train_features,train_labels)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 44,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 测试集特征提取转换（已经fit过了不需要再fit）\n",
    "test_features = dvec.transform(test_features.to_dict(orient='record'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 45,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "score准确率为0.9820\n"
     ]
    }
   ],
   "source": [
    "# 预测\n",
    "pred_labels = clf.predict(test_features)\n",
    "acc_decision_tree = round(clf.score(train_features,train_labels),6)\n",
    "print('score准确率为%.4lf'%acc_decision_tree)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 47,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "cross_val_score准确率为0.7779\n"
     ]
    }
   ],
   "source": [
    "# 使用k折交叉验证统计决策树准确率\n",
    "print('cross_val_score准确率为%.4lf'% np.mean(cross_val_score(clf,train_features,train_labels,cv=10)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 52,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Optimization Progress:  18%|█████████▍                                          | 40/220 [00:18<02:10,  1.38pipeline/s]\n",
      "Optimization Progress:  27%|██████████████▏                                     | 60/220 [00:32<02:06,  1.27pipeline/s]\n",
      "Optimization Progress:  36%|██████████████████▋                                 | 79/220 [00:53<02:24,  1.03s/pipeline]\n",
      "Optimization Progress:  45%|███████████████████████▏                           | 100/220 [01:24<04:33,  2.28s/pipeline]\n",
      "Optimization Progress:  55%|███████████████████████████▊                       | 120/220 [01:48<02:40,  1.61s/pipeline]\n",
      "Optimization Progress:  63%|███████████████████████████████▉                   | 138/220 [02:05<01:08,  1.20pipeline/s]\n",
      "Optimization Progress:  73%|█████████████████████████████████████              | 160/220 [02:35<01:11,  1.19s/pipeline]\n",
      "Optimization Progress:  82%|█████████████████████████████████████████▋         | 180/220 [03:20<01:49,  2.74s/pipeline]\n",
      "Optimization Progress:  91%|██████████████████████████████████████████████▎    | 200/220 [03:51<00:37,  1.87s/pipeline]\n",
      "Optimization Progress: 100%|███████████████████████████████████████████████████| 220/220 [04:18<00:00,  1.33s/pipeline]\n",
      "                                                                                                                       \n",
      "Best pipeline: RandomForestClassifier(XGBClassifier(SelectFromModel(input_matrix, criterion=gini, max_features=1.0, n_estimators=100, threshold=0.0), learning_rate=0.5, max_depth=6, min_child_weight=16, n_estimators=100, nthread=1, subsample=0.15000000000000002), bootstrap=False, criterion=gini, max_features=0.4, min_samples_leaf=4, min_samples_split=11, n_estimators=100)\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "TPOTClassifier(generations=10,\n",
       "               log_file=<ipykernel.iostream.OutStream object at 0x00000233C84D4D60>,\n",
       "               population_size=20, verbosity=2)"
      ]
     },
     "execution_count": 52,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# 使用TPOT自动机器学习工具对titanic进行分类\n",
    "from tpot import TPOTClassifier\n",
    "from sklearn.model_selection import train_test_split\n",
    "import numpy as np\n",
    "\n",
    "tpot = TPOTClassifier(generations=10, population_size=20, verbosity=2)\n",
    "tpot.fit(train_features,train_labels)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 53,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tpot准确率为0.8945\n"
     ]
    }
   ],
   "source": [
    "print('tpot准确率为%.4lf'%tpot.score(train_features,train_labels))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 56,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "cart_model的score准确率为0.9820\n"
     ]
    }
   ],
   "source": [
    "# 创建cart分类器\n",
    "from sklearn.tree import DecisionTreeClassifier\n",
    "cart_model = DecisionTreeClassifier(criterion='gini')\n",
    "cart_model.fit(train_features,train_labels)\n",
    "pred_labels = cart_model.predict(test_features)\n",
    "acc_decision_tree = round(cart_model.score(train_features,train_labels),6)\n",
    "print('cart_model的score准确率为%.4lf'%acc_decision_tree)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 63,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "LR准确率: 0.7991\n"
     ]
    }
   ],
   "source": [
    "# 创建LR分类器\n",
    "from sklearn.linear_model import LogisticRegression\n",
    "lr = LogisticRegression()\n",
    "lr.fit(train_features,train_labels)\n",
    "predict_y=lr.predict(test_features)\n",
    "print('LR准确率: %0.4lf' % lr.score(train_features,train_labels))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 65,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "RF准确率: 0.7991\n"
     ]
    }
   ],
   "source": [
    "from sklearn.ensemble import RandomForestClassifier \n",
    "RF = LogisticRegression()\n",
    "RF.fit(train_features,train_labels)\n",
    "predict_y=RF.predict(test_features)\n",
    "print('RF准确率: %0.4lf' % RF.score(train_features,train_labels))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.8.0 64-bit ('Bi_env': venv)",
   "language": "python",
   "name": "python38064bitbienvvenvba07af95a1bb4b078aa8134bba84dff2"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.0"
  },
  "toc-autonumbering": false
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
